[Text Pipeline] Implement Embedding Connector#338
[Text Pipeline] Implement Embedding Connector#338syhuang22 wants to merge 1 commit intoAI-Hypercomputer:mainfrom
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
099551b to
7747a33
Compare
| rngs=rngs, | ||
| ) | ||
| self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim) | ||
| self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs) |
There was a problem hiding this comment.
Diffusers uses elementwise_affine = False
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/connectors.py#L112
We should set use_scale = False
| ) | ||
| self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim) | ||
| self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs) | ||
| self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs) |
| ) | ||
|
|
||
| self.final_norm = nnx.RMSNorm( | ||
| self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs |
There was a problem hiding this comment.
ecf5514 to
ede2cdf
Compare
| rotary_emb: Optional[Tuple[Array, Array]] = None, | ||
| ) -> Array: | ||
| # 1. Norm -> Attention | ||
| normed = self.norm1(hidden_states) |
There was a problem hiding this comment.
does this need to be casted to input dtype?
There was a problem hiding this comment.
I've added .astype(hidden_states.dtype) right after the norm to cast it back
| hidden_states = hidden_states + attn_output | ||
|
|
||
| # 2. Norm -> FeedForward | ||
| normed = self.norm2(hidden_states) |
Signed-off-by: James Huang <syhuang1201@gmail.com>
ede2cdf to
ca99d98
Compare
This module acts as the crucial bridge in the LTX-2 text pipeline, responsible for processing and aligning the text embeddings (after feature extraction) before they are fed into the main diffusion model.